package com.xxd.algo.myself.leetcode;

import java.util.HashSet;
import java.util.List;

/**
 * @author: XiaoDong.Xie
 * @create: 2021-06-01 10:14
 * @description: LeetCode 139 (Word Break) - decide whether a string can be split into words from a dictionary.
 */
public class Problem_0139_WordBreak {

    /**
     * Decides whether {@code s} can be cut into consecutive pieces that all appear in
     * {@code wordDict}, using top-down recursion with memoization.
     *
     * <p>An empty {@code s} yields {@code true} (nothing is left to match).
     *
     * @param s        the string to segment
     * @param wordDict the allowed words; each word may be used any number of times
     * @return {@code true} if at least one full segmentation exists
     */
    public static boolean wordBreak(String s, List<String> wordDict) {
        // memo[i]: 0 = not computed yet, 1 = suffix from i is breakable, -1 = it is not.
        byte[] memo = new byte[s.length() + 1];
        return process(s, 0, new HashSet<>(wordDict), memo);
    }

    /**
     * Reports whether the suffix {@code s[index..]} can be fully segmented.
     *
     * @param s        the string being segmented
     * @param index    start of the suffix still to be matched
     * @param wordDict dictionary as a set for constant-time lookups
     * @param memo     per-index cache of already computed answers (0 / 1 / -1)
     * @return {@code true} if the suffix can be split into dictionary words
     */
    private static boolean process(String s, int index, HashSet<String> wordDict, byte[] memo) {
        if (index == s.length()) {
            return true;
        }
        if (memo[index] != 0) {
            return memo[index] > 0;
        }

        boolean breakable = false;
        // Try every dictionary word that starts at index; stop at the first one that works.
        for (int end = index; end < s.length() && !breakable; end++) {
            if (wordDict.contains(s.substring(index, end + 1))) {
                breakable = process(s, end + 1, wordDict, memo);
            }
        }

        // Caching failures as well as successes is what removes the exponential blow-up.
        memo[index] = (byte) (breakable ? 1 : -1);
        return breakable;
    }

    /**
     * Bottom-up version of {@link #wordBreak(String, List)}: fills a table from the end of
     * the string towards the front.
     *
     * <p>The table stores reachability rather than a count of segmentations, because a count
     * kept in an {@code int} can wrap around to zero and be mistaken for "no segmentation".
     *
     * @param s        the string to segment
     * @param wordDict the allowed words; each word may be used any number of times
     * @return {@code true} if at least one full segmentation exists
     */
    public static boolean wordBreak2(String s, List<String> wordDict) {
        HashSet<String> set = new HashSet<>(wordDict);
        int n = s.length();

        // dp[i] == true  <=>  s[i..] can be split into dictionary words.
        boolean[] dp = new boolean[n + 1];
        dp[n] = true;

        for (int index = n - 1; index >= 0; index--) {
            for (int end = index; end < n; end++) {
                // Check the cheap table lookup first to skip needless substring allocations.
                if (dp[end + 1] && set.contains(s.substring(index, end + 1))) {
                    dp[index] = true;
                    break;
                }
            }
        }

        return dp[0];
    }

    /**
     * Trie node over the 26 lowercase letters {@code 'a'..'z'}.
     */
    public static class Node {
        /** Whether a dictionary word ends exactly at this node. */
        boolean end;
        /** Children indexed by {@code letter - 'a'}; {@code null} means no such child. */
        Node[] nexts;

        public Node() {
            end = false;
            nexts = new Node[26];
        }
    }

    /**
     * Same table as {@link #wordBreak2(String, List)}, but dictionary lookups walk a trie so
     * no substrings are created.
     *
     * <p>Characters are mapped to trie slots with {@code ch - 'a'}, so both {@code s} and the
     * dictionary words are assumed to consist of lowercase letters {@code 'a'..'z'} only; any
     * other character indexes outside the 26-slot child array.
     *
     * @param s        the string to segment
     * @param wordDict the allowed words; each word may be used any number of times
     * @return {@code true} if at least one full segmentation exists
     */
    public static boolean wordBreak3(String s, List<String> wordDict) {
        Node root = buildTrie(wordDict);

        char[] str = s.toCharArray();
        int n = str.length;

        // dp[i] == true  <=>  s[i..] can be split into dictionary words.
        boolean[] dp = new boolean[n + 1];
        dp[n] = true;

        for (int index = n - 1; index >= 0; index--) {
            Node cur = root;
            for (int end = index; end < n; end++) {
                cur = cur.nexts[str[end] - 'a'];
                if (cur == null) {
                    // No dictionary word has s[index..end] as a prefix; longer ones cannot match.
                    break;
                }
                if (cur.end && dp[end + 1]) {
                    dp[index] = true;
                    break;
                }
            }
        }

        return dp[0];
    }

    /** Builds a trie containing every word of {@code wordDict} and returns its root. */
    private static Node buildTrie(List<String> wordDict) {
        Node root = new Node();
        for (String word : wordDict) {
            Node node = root;
            for (char ch : word.toCharArray()) {
                int slot = ch - 'a';
                if (node.nexts[slot] == null) {
                    node.nexts[slot] = new Node();
                }
                node = node.nexts[slot];
            }
            node.end = true;
        }
        return root;
    }

}
